Unroll _transform_tuple to fix Enzyme.autodiff on tuples of length ≥ 33#170
Unroll _transform_tuple to fix Enzyme.autodiff on tuples of length ≥ 33#170jlperla wants to merge 1 commit into
Conversation
The recursive Base.tail fold in _transform_tuple makes Enzyme.autodiff
(Forward and Reverse) throw `AssertionError("conv == 37")` from
Enzyme/src/rules/jitrules.jl:2073 once the tuple has ≥ 33 entries
(EnzymeAD/Enzyme.jl#3104). Replace it with a @generated straight-line
unroll that produces the same outputs bit-for-bit while emitting no
self-invoke in the typed IR — which is what Enzyme trips on.
Verified against the full Pkg.test() suite (all Pass = Total) and a
35-entry SW07-Pfeifer-style NamedTuple prior (fwd + rev both succeed).
|
Any chance that this will be merged here? I understand that the real fix should be on Enzyme's side, but that may be much harder. Thanks! PS: The is my real world MWE that lead me finally to this PR; maybe it is useful for someone. using Distributions
using Enzyme
using TransformVariables
N = 33
dists = ntuple(i -> LogNormal(0.0, 1.0), N)
dists = NamedTuple{ntuple(i -> Symbol("x", i), N)}(dists)
function prior_transform(priors)
transforms = map(priors) do prior
left, right = extrema(support(prior))
left = isinf(left) ? -TransformVariables.∞ : left
right = isinf(right) ? TransformVariables.∞ : right
TransformVariables.as(Real, left, right)
end
TransformVariables.as(transforms)
end
trans = prior_transform(dists)
q = fill(-0.1, TransformVariables.dimension(trans))
foo(q) = sum(values(TransformVariables.transform(trans, q)))
Enzyme.gradient(Enzyme.Reverse, foo, q) # AssertionError: conv == 37 |
|
@jlperla, thanks for this, @scheidan, thanks for the ping. I apologize for the delay in reviewing this. It is not strictly equivalent as, AFAIK, built-ins do not necessarily unroll above a certain tuple length. But given that the intention of using a tuple is to get type-stable code, I don't see a problem with this here. Also, EnzymeAD/Enzyme.jl#3104 indicates that this is an issue on the Julia side, so fixing it on our end may be the best option for now. @devmotion, this is fine with me, do you have any comments? |
|
(closing and reopening to make CI run) |
| Implemented as a `@generated` straight-line unroll over the static tuple length. | ||
| Equivalent to the natural `Base.tail` recursion, but emits non-recursive code | ||
| so that `Enzyme.autodiff` does not hit `AssertionError("conv == 37")` on | ||
| tuples of length ≥ 33 (EnzymeAD/Enzyme.jl#3104). |
There was a problem hiding this comment.
This seems very internal for being part of a docstrings? It also might change again in case of upstream compiler or Enzyme changes.
| Implemented as a `@generated` straight-line unroll over the static tuple length. | |
| Equivalent to the natural `Base.tail` recursion, but emits non-recursive code | |
| so that `Enzyme.autodiff` does not hit `AssertionError("conv == 37")` on | |
| tuples of length ≥ 33 (EnzymeAD/Enzyme.jl#3104). |
There was a problem hiding this comment.
This is an internal helper function anyway and not part of the API. (I like to document my internal functions too, I know this is not common to do so). As far as I am concerned this is fine.
| for i in 1:N] | ||
| ℓ_sum = foldl((a, b) -> :($a + $b), ℓs) | ||
| return quote | ||
| idx = index |
There was a problem hiding this comment.
Why is a separate idx variable needed? Couldn't we just operate with index?
Replace the
Base.tail-recursive_transform_tuplewith a@generatedstraight-line unroll — same outputs bit-for-bit, but the typed IR no longer contains a self-invoke, which is whatEnzyme.autodiff(Forward and Reverse) trips on at tuple length ≥ 33 withAssertionError("conv == 37")(EnzymeAD/Enzyme.jl#3104).